1. Import Libraries

In [1]:
# # import library for usage of GPU
# import os
# os.environ["CUDA_DEVICE_ORDER"]="PCI_BUS_ID"
# os.environ["CUDA_VISIBLE_DEVICES"]="0"
In [2]:
# import library
import tensorflow as tf
import matplotlib.pyplot as plt
import numpy as np
from glob import glob
from PIL import Image
from scipy import misc
from sklearn.model_selection import train_test_split

2. Load Data

In [3]:
# load in-focus and out-of-focus SEM images
# Inputs (out-of-focus) and targets (in-focus) are paired by sorting the two
# directory listings independently — this assumes both folders contain
# matching filenames (TODO confirm against the dataset layout).

def _load_level_paths(level):
    """Return sorted (train_images, train_labels, test_images, test_labels)
    path lists for one defocus level."""
    train_imgs = sorted(glob('./data_files/out-of-focus_level_{}/train/*.*'.format(level)))
    train_lbls = sorted(glob('./data_files/infocus_level_{}/train/*.*'.format(level)))
    test_imgs = sorted(glob('./data_files/out-of-focus_level_{}/test/*.*'.format(level)))
    test_lbls = sorted(glob('./data_files/infocus_level_{}/test/*.*'.format(level)))
    # Every input image needs exactly one ground-truth counterpart.
    assert len(train_imgs) == len(train_lbls), 'train pair mismatch at level {}'.format(level)
    assert len(test_imgs) == len(test_lbls), 'test pair mismatch at level {}'.format(level)
    return train_imgs, train_lbls, test_imgs, test_lbls

# split the SEM images into train, validation, and test data
# 10% of each level's training pairs are held out for validation so that all
# three defocus levels are represented in every split.
# NOTE(review): train_test_split is unseeded, so the split differs between
# runs; pass random_state=... if a reproducible split is wanted.
_train_imgs, _train_lbls = [], []
_valid_imgs, _valid_lbls = [], []
_test_imgs, _test_lbls = [], []
for _level in (1, 2, 3):
    _tr_i, _tr_l, _te_i, _te_l = _load_level_paths(_level)
    _tr_i, _va_i, _tr_l, _va_l = train_test_split(_tr_i, _tr_l, test_size=0.10)
    _train_imgs.append(_tr_i)
    _train_lbls.append(_tr_l)
    _valid_imgs.append(_va_i)
    _valid_lbls.append(_va_l)
    _test_imgs.append(_te_i)
    _test_lbls.append(_te_l)

# concatenate the per-level lists into flat arrays of file paths
train_images = np.concatenate(_train_imgs)
train_labels = np.concatenate(_train_lbls)
valid_images = np.concatenate(_valid_imgs)
valid_labels = np.concatenate(_valid_lbls)
test_images = np.concatenate(_test_imgs)
test_labels = np.concatenate(_test_lbls)

2.1. Batch Maker

In [4]:
# data augmentation technique applied for MRN powered by DA
def cutblur(image, label, cut_size=None):
    """CutBlur-style augmentation: paste a random square patch of the sharp
    label into the blurry input at the same spatial location.

    Parameters
    ----------
    image : 2-D ndarray
        Out-of-focus input; modified in place.
    label : 2-D ndarray
        In-focus ground truth; left untouched.
    cut_size : int, optional
        Side length of the pasted square.  Defaults to ``crop_size / 4``
        using the module-level ``crop_size`` (preserves the original
        behavior for existing call sites).

    Returns
    -------
    (image, label) : the augmented input and the unchanged label.
    """
    if cut_size is None:
        # Original behavior: the patch is a quarter of the training crop.
        cut_size = int(crop_size / 4)

    height = image.shape[0]
    width = image.shape[1]
    # Top-left corner chosen so the patch stays fully inside the image.
    top = np.random.randint(height - cut_size)
    left = np.random.randint(width - cut_size)

    sharp_patch = label[top:top + cut_size, left:left + cut_size].copy()
    image[top:top + cut_size, left:left + cut_size] = sharp_patch

    return image, label
In [5]:
# implement batch maker used for training phase
crop_size = 256  # side length (pixels) of the square training crop
factor = 2       # downscale factor between adjacent pyramid scales

current_batch = 0  # cursor into train_images for sequential batching
def train_batch_maker(batch_size):
    """Build one training mini-batch for the three-scale MRN.

    Reads ``batch_size`` image/label path pairs sequentially from the global
    ``train_images``/``train_labels`` arrays (reshuffling both in unison when
    an epoch is exhausted), random-crops each pair to ``crop_size``, applies
    the CutBlur augmentation, and returns the batch at three resolutions
    (finer = crop_size, intermediate = crop_size/factor, coarsest =
    crop_size/factor**2), normalized to [0, 1] with a trailing channel axis.

    Returns
    -------
    (images_coarsest, images_intermediate, images_finer,
     labels_coarsest, labels_intermediate, labels_finer) : six 4-D ndarrays.
    """
    global current_batch
    global train_images
    global train_labels

    if len(train_images) - current_batch >= batch_size:
        batch_train_images = train_images[current_batch:current_batch + batch_size]
        batch_train_labels = train_labels[current_batch:current_batch + batch_size]
        current_batch += batch_size
    else:
        # Epoch exhausted: reshuffle images and labels with the SAME
        # permutation (so pairs stay aligned) and restart from the front.
        # BUGFIX: previously the shuffled views were assigned to the batch
        # variables and then immediately overwritten from the *unshuffled*
        # globals, so the data order never actually changed between epochs.
        # BUGFIX: the cursor is also advanced here; before, it stayed at 0,
        # so the first batch of the new epoch was served twice in a row.
        idx_train = np.arange(len(train_images))
        np.random.shuffle(idx_train)
        train_images = train_images[idx_train]
        train_labels = train_labels[idx_train]

        current_batch = 0
        batch_train_images = train_images[current_batch:current_batch + batch_size]
        batch_train_labels = train_labels[current_batch:current_batch + batch_size]
        current_batch += batch_size

    images_coarsest, images_intermediate, images_finer = [], [], []
    labels_coarsest, labels_intermediate, labels_finer = [], [], []

    for image_path, label_path in zip(batch_train_images, batch_train_labels):
        temp_image = np.array(Image.open(image_path))
        temp_label = np.array(Image.open(label_path))

        # Random crop_size x crop_size crop, identical window for input and label.
        x_size = temp_image.shape[0]
        y_size = temp_image.shape[1]
        start_x = np.random.randint(x_size - crop_size)
        start_y = np.random.randint(y_size - crop_size)
        temp_image = temp_image[start_x:start_x + crop_size, start_y:start_y + crop_size]
        temp_label = temp_label[start_x:start_x + crop_size, start_y:start_y + crop_size]

        # CutBlur: paste a sharp patch of the label into the blurry input.
        temp_image, temp_label = cutblur(temp_image, temp_label)

        # Build the three-scale pyramid; add a trailing channel axis.
        # NOTE(review): scipy.misc.imresize is deprecated (removed in SciPy
        # 1.3); migrate to np.array(Image.fromarray(arr).resize(...)) when
        # upgrading the environment.
        img_finer = temp_image.copy()[:, :, np.newaxis]
        img_inter = misc.imresize(temp_image, 1.0 / factor, interp='bicubic')[:, :, np.newaxis]
        img_coarse = misc.imresize(temp_image, 1.0 / (factor ** 2), interp='bicubic')[:, :, np.newaxis]
        lbl_finer = temp_label.copy()[:, :, np.newaxis]
        lbl_inter = misc.imresize(temp_label, 1.0 / factor, interp='bicubic')[:, :, np.newaxis]
        lbl_coarse = misc.imresize(temp_label, 1.0 / (factor ** 2), interp='bicubic')[:, :, np.newaxis]

        # Rescale 8-bit intensities to [0, 1].
        images_coarsest.append(img_coarse / 255.0)
        images_intermediate.append(img_inter / 255.0)
        images_finer.append(img_finer / 255.0)
        labels_coarsest.append(lbl_coarse / 255.0)
        labels_intermediate.append(lbl_inter / 255.0)
        labels_finer.append(lbl_finer / 255.0)

    return (np.array(images_coarsest), np.array(images_intermediate),
            np.array(images_finer), np.array(labels_coarsest),
            np.array(labels_intermediate), np.array(labels_finer))
In [6]:
# show samples of training data for MRN powered by DA
train_images_coarsest, train_images_intermediate, train_images_finer, train_labels_coarsest, train_labels_intermediate, train_labels_finer = train_batch_maker(5)

# One figure per pyramid scale, for both the blurred inputs and the sharp
# ground-truth targets (first sample of the batch, single grayscale channel).
for caption, batch in [
        ('Input data for coarsest scale network', train_images_coarsest),
        ('Input data for intermediate scale network', train_images_intermediate),
        ('Input data for finer scale network', train_images_finer),
        ('Ground truth data for coarsest scale network', train_labels_coarsest),
        ('Ground truth data for intermediate scale network', train_labels_intermediate),
        ('Ground truth data for finer scale network', train_labels_finer)]:
    print(caption)
    plt.figure(figsize=(20, 20))
    plt.imshow(batch[0, :, :, 0], cmap='gray')
    plt.show()
/mnt/disk1/project/.env/lib/python3.6/site-packages/ipykernel_launcher.py:49: DeprecationWarning: `imresize` is deprecated!
`imresize` is deprecated in SciPy 1.0.0, and will be removed in 1.3.0.
Use Pillow instead: ``numpy.array(Image.fromarray(arr).resize())``.
/mnt/disk1/project/.env/lib/python3.6/site-packages/ipykernel_launcher.py:50: DeprecationWarning: `imresize` is deprecated!
`imresize` is deprecated in SciPy 1.0.0, and will be removed in 1.3.0.
Use Pillow instead: ``numpy.array(Image.fromarray(arr).resize())``.
/mnt/disk1/project/.env/lib/python3.6/site-packages/ipykernel_launcher.py:52: DeprecationWarning: `imresize` is deprecated!
`imresize` is deprecated in SciPy 1.0.0, and will be removed in 1.3.0.
Use Pillow instead: ``numpy.array(Image.fromarray(arr).resize())``.
/mnt/disk1/project/.env/lib/python3.6/site-packages/ipykernel_launcher.py:53: DeprecationWarning: `imresize` is deprecated!
`imresize` is deprecated in SciPy 1.0.0, and will be removed in 1.3.0.
Use Pillow instead: ``numpy.array(Image.fromarray(arr).resize())``.
Input data for coarsest scale network
Input data for intermediate scale network
Input data for finer scale network
Ground truth data for coarsest scale network
Ground truth data for intermediate scale network
Ground truth data for finer scale network

2.2. Validation Data Preprocessing

In [7]:
# Preprocess valid data used for validation phase
valid_images_coarsest = []
valid_images_intermediate = []
valid_images_finer = []
valid_labels_coarsest = []
valid_labels_intermediate = []
valid_labels_finer = []

for image_path, label_path in zip(valid_images, valid_labels):
    blurry = np.array(Image.open(image_path))
    sharp = np.array(Image.open(label_path))

    # Deterministic crop at a fixed offset so the validation loss is always
    # computed on the same window of each image.
    top, left = 250, 250
    blurry = blurry[top:top + crop_size, left:left + crop_size]
    sharp = sharp[top:top + crop_size, left:left + crop_size]

    # NOTE(review): CutBlur is applied to validation samples too —
    # presumably intentional for "MRN powered by DA", but confirm.
    blurry, sharp = cutblur(blurry, sharp)

    # Three-scale pyramid, channel axis appended, intensities in [0, 1].
    valid_images_finer.append(blurry.copy()[:, :, np.newaxis] / 255.0)
    valid_images_intermediate.append(misc.imresize(blurry, 1.0 / factor, interp='bicubic')[:, :, np.newaxis] / 255.0)
    valid_images_coarsest.append(misc.imresize(blurry, 1.0 / (factor ** 2), interp='bicubic')[:, :, np.newaxis] / 255.0)
    valid_labels_finer.append(sharp.copy()[:, :, np.newaxis] / 255.0)
    valid_labels_intermediate.append(misc.imresize(sharp, 1.0 / factor, interp='bicubic')[:, :, np.newaxis] / 255.0)
    valid_labels_coarsest.append(misc.imresize(sharp, 1.0 / (factor ** 2), interp='bicubic')[:, :, np.newaxis] / 255.0)

valid_images_coarsest = np.array(valid_images_coarsest)
valid_images_intermediate = np.array(valid_images_intermediate)
valid_images_finer = np.array(valid_images_finer)
valid_labels_coarsest = np.array(valid_labels_coarsest)
valid_labels_intermediate = np.array(valid_labels_intermediate)
valid_labels_finer = np.array(valid_labels_finer)
/mnt/disk1/project/.env/lib/python3.6/site-packages/ipykernel_launcher.py:24: DeprecationWarning: `imresize` is deprecated!
`imresize` is deprecated in SciPy 1.0.0, and will be removed in 1.3.0.
Use Pillow instead: ``numpy.array(Image.fromarray(arr).resize())``.
/mnt/disk1/project/.env/lib/python3.6/site-packages/ipykernel_launcher.py:25: DeprecationWarning: `imresize` is deprecated!
`imresize` is deprecated in SciPy 1.0.0, and will be removed in 1.3.0.
Use Pillow instead: ``numpy.array(Image.fromarray(arr).resize())``.
/mnt/disk1/project/.env/lib/python3.6/site-packages/ipykernel_launcher.py:27: DeprecationWarning: `imresize` is deprecated!
`imresize` is deprecated in SciPy 1.0.0, and will be removed in 1.3.0.
Use Pillow instead: ``numpy.array(Image.fromarray(arr).resize())``.
/mnt/disk1/project/.env/lib/python3.6/site-packages/ipykernel_launcher.py:28: DeprecationWarning: `imresize` is deprecated!
`imresize` is deprecated in SciPy 1.0.0, and will be removed in 1.3.0.
Use Pillow instead: ``numpy.array(Image.fromarray(arr).resize())``.

3. Implement MRN Powered by DA

3.1. MRN Networks

In [8]:
# define placeholders
# One input (x_*) / target (y_*) pair per pyramid scale.  Batch and spatial
# dimensions are left dynamic (None) so both fixed-size training crops and
# full-size test images can be fed; images are single-channel (grayscale).
x_coarsest = tf.placeholder(tf.float32, [None, None, None, 1], name='x_coarsest')
x_intermediate = tf.placeholder(tf.float32, [None, None, None, 1], name='x_intermediate')
x_finer = tf.placeholder(tf.float32, [None, None, None, 1], name='x_finer')

y_coarsest = tf.placeholder(tf.float32, [None, None, None, 1], name='y_coarsest')
y_intermediate = tf.placeholder(tf.float32, [None, None, None, 1], name='y_intermediate')
y_finer = tf.placeholder(tf.float32, [None, None, None, 1], name='y_finer')
In [9]:
# define modules
def ResidualBlock(x, kernel_size, filters, strides = 1):
    """Residual block: conv -> PReLU -> conv, plus an identity shortcut
    (no batch normalization, biases disabled on both convolutions)."""
    shortcut = x
    x = tf.layers.conv2d(x,
                         filters = filters,
                         kernel_size = kernel_size,
                         strides = strides,
                         padding = 'same',
                         use_bias = False)
    # PReLU with its slope shared across spatial axes (one slope per channel).
    x = tf.contrib.keras.layers.PReLU(shared_axes = [1,2])(x)
    x = tf.layers.conv2d(x,
                         filters = filters,
                         kernel_size = kernel_size,
                         strides = strides,
                         padding = 'same',
                         use_bias = False)
    return x + shortcut

def Upsample2xBlock(x, kernel_size, filters, name, strides = 1):
    """2x spatial upsampling: conv, then depth_to_space (pixel shuffle)
    trading 4 channels for a 2x2 spatial expansion, then ReLU."""
    with tf.variable_scope(name, reuse=tf.AUTO_REUSE):
        features = tf.layers.conv2d(x,
                                    filters = filters,
                                    kernel_size = kernel_size,
                                    strides = strides,
                                    padding = 'same')
        upsampled = tf.depth_to_space(features, 2)
        return tf.nn.relu(upsampled)
In [10]:
# define subnetworks
def _resnet_branch(x, num_blocks, scope_name):
    """Shared body for the three scale-specific subnetworks (previously three
    verbatim copies of the same code).

    Structure: head conv (5x5, 64) + PReLU, ``num_blocks`` residual blocks,
    a tail conv with a long skip back to the head output, then a final
    1-channel conv (named 'forward') squashed through a sigmoid so outputs
    lie in [0, 1].  The per-branch variable scope names are kept identical
    to the original functions so existing checkpoints and the substring
    matching used to build the optimizer var_lists still work.
    """
    with tf.variable_scope(scope_name, reuse=tf.AUTO_REUSE):
        x = tf.layers.conv2d(x,
                             kernel_size = 5,
                             filters = 64,
                             strides = 1,
                             padding = 'same')
        x = tf.contrib.keras.layers.PReLU(shared_axes = [1,2])(x)
        skip = x

        for _ in range(num_blocks):
            x = ResidualBlock(x, kernel_size = 5, filters = 64, strides = 1)

        # Tail conv + long (global) residual connection around all blocks.
        x = tf.layers.conv2d(x,
                             kernel_size = 5,
                             filters = 64,
                             strides = 1,
                             padding = 'same',
                             use_bias = False)
        x = x + skip

        # Project back to a single grayscale channel.
        x = tf.layers.conv2d(x,
                             kernel_size = 5,
                             filters = 1,
                             strides = 1,
                             padding = 'same',
                             name = 'forward')
        return tf.nn.sigmoid(x)

def resnet_coarsest(x, num_blocks):
    """Refocusing subnetwork for the coarsest (1/factor**2) scale."""
    return _resnet_branch(x, num_blocks, 'resnet_coarsest')

def resnet_intermediate(x, num_blocks):
    """Refocusing subnetwork for the intermediate (1/factor) scale."""
    return _resnet_branch(x, num_blocks, 'resnet_intermediate')

def resnet_finer(x, num_blocks):
    """Refocusing subnetwork for the full-resolution scale."""
    return _resnet_branch(x, num_blocks, 'resnet_finer')

3.2. Loss & Optimizer

In [11]:
# networks
## coarsest level network
# Coarse-to-fine cascade: each branch's output is upsampled 2x and
# concatenated (channel axis) with the next scale's blurry input, so finer
# branches are conditioned on the coarser refocusing result.
refocus_coarsest = resnet_coarsest(x_coarsest, 16)
refocus_coarsest_upconv = Upsample2xBlock(refocus_coarsest, kernel_size = 3, filters = 4, name = 'upconv_for_intermediate')
refocus_coarsest_upconv_concat = tf.concat((refocus_coarsest_upconv, x_intermediate), axis = 3)

## intermediate level network
refocus_intermediate = resnet_intermediate(refocus_coarsest_upconv_concat, 16)
refocus_intermediate_upconv = Upsample2xBlock(refocus_intermediate, kernel_size = 3, filters = 4, name = 'upconv_for_finer')
refocus_intermediate_upconv_concat = tf.concat((refocus_intermediate_upconv, x_finer), axis = 3)

## finer level network
refocus_finer = resnet_finer(refocus_intermediate_upconv_concat, 16)

# loss
# Per-scale L1 (mean absolute error) between prediction and ground truth.
loss_coarsest = tf.reduce_mean(tf.abs(y_coarsest - refocus_coarsest))
loss_intermediate = tf.reduce_mean(tf.abs(y_intermediate - refocus_intermediate))
loss_finer = tf.reduce_mean(tf.abs(y_finer - refocus_finer))

# learning rate
# Continuous (staircase=False) exponential decay: starts at 5e-5 and decays
# by a factor of 0.1 over 50000 steps — matching n_iter, so the LR ends near
# LR*0.1 at the end of training.  global_step is advanced manually via
# incr_global_step (the optimizers below do not pass a global_step).
LR = 0.00005
global_step = tf.contrib.framework.get_or_create_global_step()
learning_rate = tf.train.exponential_decay(LR, global_step, 50000, 0.1, staircase = False)
incr_global_step = tf.assign(global_step, global_step + 1)

# variable list
# Branch ownership is resolved by scope-name substring: each upconv block is
# trained together with the branch that consumes its output.
var_coarsest = [var for var in tf.get_collection('trainable_variables') if 'resnet_coarsest' in var.name]
var_intermediate = [var for var in tf.get_collection('trainable_variables') if 'resnet_intermediate' in var.name or 'upconv_for_intermediate' in var.name]
var_finer = [var for var in tf.get_collection('trainable_variables') if 'resnet_finer' in var.name or 'upconv_for_finer' in var.name]

# optimizer
# Three independent Adam optimizers, one per scale, each restricted to its
# own branch's variables via var_list.
optm_coarsest = tf.train.AdamOptimizer(learning_rate).minimize(loss_coarsest, var_list = var_coarsest)
optm_intermediate = tf.train.AdamOptimizer(learning_rate).minimize(loss_intermediate, var_list = var_intermediate)
optm_finer = tf.train.AdamOptimizer(learning_rate).minimize(loss_finer, var_list = var_finer)
WARNING: Logging before flag parsing goes to stderr.
W1104 00:38:34.138664 140485341157184 deprecation.py:323] From <ipython-input-10-30d5ef84962a>:8: conv2d (from tensorflow.python.layers.convolutional) is deprecated and will be removed in a future version.
Instructions for updating:
Use `tf.keras.layers.Conv2D` instead.
W1104 00:38:34.145472 140485341157184 deprecation.py:506] From /mnt/disk1/project/.env/lib/python3.6/site-packages/tensorflow/python/ops/init_ops.py:1251: calling VarianceScaling.__init__ (from tensorflow.python.ops.init_ops) with dtype is deprecated and will be removed in a future version.
Instructions for updating:
Call initializer instance with the dtype argument instead of passing it to the constructor
W1104 00:38:34.848397 140485341157184 lazy_loader.py:50] 
The TensorFlow contrib module will not be included in TensorFlow 2.0.
For more information, please see:
  * https://github.com/tensorflow/community/blob/master/rfcs/20180907-contrib-sunset.md
  * https://github.com/tensorflow/addons
  * https://github.com/tensorflow/io (for I/O related ops)
If you depend on functionality not listed there, please file an issue.

W1104 00:38:36.355923 140485341157184 deprecation.py:323] From <ipython-input-11-110f8e96db38>:22: get_or_create_global_step (from tensorflow.contrib.framework.python.ops.variables) is deprecated and will be removed in a future version.
Instructions for updating:
Please switch to tf.train.get_or_create_global_step

3.3. Optimization

In [12]:
# training parameters
n_iter = 50000  # total optimization steps (matches the LR decay horizon above)
n_prt = 100     # logging / preview-plot interval, in steps
n_batch = 5     # mini-batch size passed to train_batch_maker
In [ ]:
# open a session for training
sess = tf.Session()
sess.run(tf.global_variables_initializer())

saver = tf.train.Saver()

# optimize a model during n_iter
# `criteria` tracks the best (lowest) finer-scale validation loss seen so
# far; the model is checkpointed whenever it improves.  10 is far above any
# possible L1 loss on [0, 1] images, so the first iteration always saves.
criteria = 10
loss_train_record = []
loss_valid_record = []
# NOTE: each loop iteration is one mini-batch step, not a full epoch,
# despite the variable name.
for epoch in range(n_iter):
    train_images_coarsest, train_images_intermediate, train_images_finer, train_labels_coarsest, train_labels_intermediate, train_labels_finer = train_batch_maker(n_batch)

    # One optimization step for each scale's branch (single feed, three
    # separate losses / var_lists).
    sess.run([optm_coarsest, optm_intermediate, optm_finer], feed_dict = {x_coarsest: train_images_coarsest, 
                                                                          x_intermediate: train_images_intermediate,
                                                                          x_finer: train_images_finer, 
                                                                          y_coarsest: train_labels_coarsest, 
                                                                          y_intermediate: train_labels_intermediate, 
                                                                          y_finer: train_labels_finer})
    # Advance the step counter that drives the exponential LR decay.
    sess.run(incr_global_step)
    # NOTE(review): the finer-scale loss over the ENTIRE validation set is
    # evaluated every single iteration purely for checkpoint selection —
    # this is likely the dominant cost; consider evaluating every n_prt
    # steps instead.
    criteria_temp = sess.run(loss_finer, feed_dict = {x_coarsest: valid_images_coarsest, 
                                                      x_intermediate: valid_images_intermediate,
                                                      x_finer: valid_images_finer, 
                                                      y_coarsest: valid_labels_coarsest, 
                                                      y_intermediate: valid_labels_intermediate, 
                                                      y_finer: valid_labels_finer})

    # Checkpoint on validation improvement (best-model selection).
    if criteria > criteria_temp:
        criteria = criteria_temp
        saver.save(sess, './model/MRN.ckpt')

    # Periodic logging: record train/valid loss and show a refocused sample.
    if epoch % n_prt == 0:
        loss_train = sess.run(loss_finer, feed_dict = {x_coarsest: train_images_coarsest, 
                                                       x_intermediate: train_images_intermediate,
                                                       x_finer: train_images_finer, 
                                                       y_coarsest: train_labels_coarsest, 
                                                       y_intermediate: train_labels_intermediate, 
                                                       y_finer: train_labels_finer})
        loss_valid = sess.run(loss_finer, feed_dict = {x_coarsest: valid_images_coarsest, 
                                                       x_intermediate: valid_images_intermediate,
                                                       x_finer: valid_images_finer, 
                                                       y_coarsest: valid_labels_coarsest, 
                                                       y_intermediate: valid_labels_intermediate, 
                                                       y_finer: valid_labels_finer})
        loss_train_record.append(loss_train)
        loss_valid_record.append(loss_valid)
        
        print('Epoch:', '%04d' % epoch, 'loss_train: {:.4}'.format(loss_train), 'loss_valid: {:.4}'.format(loss_valid))
        # Preview: refocus the first training sample of the current batch.
        refocus_img = sess.run(refocus_finer, feed_dict = {x_coarsest: train_images_coarsest[:1], 
                                                           x_intermediate: train_images_intermediate[:1],
                                                           x_finer: train_images_finer[:1],
                                                           y_coarsest: train_labels_coarsest[:1],
                                                           y_intermediate: train_labels_intermediate[:1],
                                                           y_finer: train_labels_finer[:1]})
        
        # NOTE(review): figures created here are never closed; for long runs
        # consider plt.close() after plt.show() to bound notebook memory.
        plt.figure(figsize = (5,5))
        plt.imshow(refocus_img[0,:,:,0], cmap = 'gray')
        plt.axis('off')
        plt.show()

3.4. Best Model Loading

In [13]:
# load the best model
# Restore the checkpoint written whenever the validation loss improved
# during training.  Note a fresh Session is opened here; the training
# session above is not reused.
save_file = './model/MRN.ckpt'
saver = tf.train.Saver()
sess = tf.Session()
saver.restore(sess, save_file)
W1104 00:38:40.437350 140485341157184 deprecation.py:323] From /mnt/disk1/project/.env/lib/python3.6/site-packages/tensorflow/python/training/saver.py:1276: checkpoint_exists (from tensorflow.python.training.checkpoint_management) is deprecated and will be removed in a future version.
Instructions for updating:
Use standard file APIs to check for files with this prefix.

4. Evaluate the Trained MRN Model

4.1. Test Images

In [14]:
# evaluate the test images of level 1, 2, and 3 by MRN powered by DA
# Pick one test image/label pair at random (all three defocus levels are
# mixed in test_images).
n = np.random.randint(len(test_images))

test_image = Image.open(test_images[n])
test_image = np.array(test_image)
test_label = Image.open(test_labels[n])
test_label = np.array(test_label)

# Crop so both dimensions divide evenly by factor**2: the pyramid halves the
# resolution twice, so all three scales must have integer shapes that align
# through the 2x upsampling blocks.
if (test_image.shape[0] / factor**2) % 1 != 0 or (test_image.shape[1] / factor**2) % 1 != 0:
    new_x_shape = int(test_image.shape[0] / factor**2) * factor**2
    new_y_shape = int(test_image.shape[1] / factor**2) * factor**2
    test_image = test_image[:new_x_shape,:new_y_shape]
    test_label = test_label[:new_x_shape,:new_y_shape]

# Build the three-scale inputs; `/ 255` (true division) yields float arrays
# in [0, 1], matching the training normalization.
test_image_finer = test_image.copy()[np.newaxis,:,:,np.newaxis] / 255
test_image_intermediate = misc.imresize(test_image, 1.0/factor, interp = 'bicubic')[np.newaxis,:,:,np.newaxis] / 255
test_image_coarsest = misc.imresize(test_image, 1.0/(factor**2), interp = 'bicubic')[np.newaxis,:,:,np.newaxis] / 255

test_label_finer = test_label.copy()[np.newaxis,:,:,np.newaxis] / 255
test_label_intermediate = misc.imresize(test_label, 1.0/factor, interp = 'bicubic')[np.newaxis,:,:,np.newaxis] / 255
test_label_coarsest = misc.imresize(test_label, 1.0/(factor**2), interp = 'bicubic')[np.newaxis,:,:,np.newaxis] / 255

# Forward pass only; the y_* labels are fed but not needed to evaluate
# refocus_finer (no loss tensor is fetched here).
refocus_img = sess.run(refocus_finer, feed_dict = {x_coarsest: test_image_coarsest, 
                                                 x_intermediate: test_image_intermediate,
                                                 x_finer: test_image_finer, 
                                                 y_coarsest: test_label_coarsest, 
                                                 y_intermediate: test_label_intermediate, 
                                                 y_finer: test_label_finer})

# show the out-of-focus (input), refocus (output), and in-focus (label) images
plt.figure(figsize = (20,20))
plt.imshow(refocus_img[0,:,:,0], cmap = 'gray')
plt.axis('off')
plt.show()

plt.figure(figsize = (20,20))
plt.imshow(test_image[:,:], cmap = 'gray')
plt.axis('off')
plt.show()

plt.figure(figsize = (20,20))
plt.imshow(test_label[:,:], cmap = 'gray')
plt.axis('off')
plt.show()
/mnt/disk1/project/.env/lib/python3.6/site-packages/ipykernel_launcher.py:16: DeprecationWarning: `imresize` is deprecated!
`imresize` is deprecated in SciPy 1.0.0, and will be removed in 1.3.0.
Use Pillow instead: ``numpy.array(Image.fromarray(arr).resize())``.
  app.launch_new_instance()
/mnt/disk1/project/.env/lib/python3.6/site-packages/ipykernel_launcher.py:17: DeprecationWarning: `imresize` is deprecated!
`imresize` is deprecated in SciPy 1.0.0, and will be removed in 1.3.0.
Use Pillow instead: ``numpy.array(Image.fromarray(arr).resize())``.
/mnt/disk1/project/.env/lib/python3.6/site-packages/ipykernel_launcher.py:20: DeprecationWarning: `imresize` is deprecated!
`imresize` is deprecated in SciPy 1.0.0, and will be removed in 1.3.0.
Use Pillow instead: ``numpy.array(Image.fromarray(arr).resize())``.
/mnt/disk1/project/.env/lib/python3.6/site-packages/ipykernel_launcher.py:21: DeprecationWarning: `imresize` is deprecated!
`imresize` is deprecated in SciPy 1.0.0, and will be removed in 1.3.0.
Use Pillow instead: ``numpy.array(Image.fromarray(arr).resize())``.

4.2. Non-uniformly Defocused Image

In [15]:
# refocus non-uniformly defocused image by MRN powered by DA
file_name = './data_files/non_uniformly_defocused_image.png'

# Load the very wide image ONCE (it was previously reopened and re-decoded
# on every loop iteration) and process it in vertical strips so each strip
# fits through the network.
full_image = np.array(Image.open(file_name))

non_uniform_refocus_list = []
non_uniform_abs_residual_intensity_map_list = []

chunk_width = 5000
# NOTE(review): 58538 appears to be the image width — TODO derive it from
# full_image.shape[1] instead of hard-coding.
for i in range(58538 // chunk_width + 1):
    # NumPy slicing silently clips an out-of-range stop, so the last
    # (narrower) strip needs no special-casing — the original bare-except
    # try/except around this slice was dead code that could never trigger.
    test_image = full_image[:, i * chunk_width:(i + 1) * chunk_width]

    # Crop so both dimensions divide evenly by factor**2; the network halves
    # the resolution twice, so shapes must align across the three scales.
    if (test_image.shape[0] / factor**2) % 1 != 0 or (test_image.shape[1] / factor**2) % 1 != 0:
        new_x_shape = int(test_image.shape[0] / factor**2) * factor**2
        new_y_shape = int(test_image.shape[1] / factor**2) * factor**2
        test_image = test_image[:new_x_shape, :new_y_shape]

    # Three-scale inputs in [0, 1]; no labels are fed (forward pass only).
    test_image_finer = test_image.copy()[np.newaxis, :, :, np.newaxis] / 255
    test_image_intermediate = misc.imresize(test_image, 1.0/factor, interp = 'bicubic')[np.newaxis, :, :, np.newaxis] / 255
    test_image_coarsest = misc.imresize(test_image, 1.0/(factor**2), interp = 'bicubic')[np.newaxis, :, :, np.newaxis] / 255

    refocus_img = sess.run(refocus_finer, feed_dict = {x_coarsest: test_image_coarsest,
                                                       x_intermediate: test_image_intermediate,
                                                       x_finer: test_image_finer})
    non_uniform_refocus_list.append(refocus_img[0, :, :, 0])
    non_uniform_abs_residual_intensity_map_list.append(np.abs(refocus_img[0, :, :, 0] - test_image[:, :] / 255))

# Stitch the strips back together in a single call instead of growing the
# array with hstack inside a loop (which copies the accumulated result on
# every iteration).
non_uniform_refocus = np.hstack(non_uniform_refocus_list)
non_uniform_abs_residual_intensity_map = np.hstack(non_uniform_abs_residual_intensity_map_list)

# show the refocused image & residual intensity map
plt.figure(figsize = (20,20))
plt.imshow(non_uniform_refocus, 'gray')
plt.axis('off')
plt.show()

plt.figure(figsize = (20,20))
plt.imshow(non_uniform_abs_residual_intensity_map)
plt.axis('off')
plt.show()
/mnt/disk1/project/.env/lib/python3.6/site-packages/ipykernel_launcher.py:21: DeprecationWarning: `imresize` is deprecated!
`imresize` is deprecated in SciPy 1.0.0, and will be removed in 1.3.0.
Use Pillow instead: ``numpy.array(Image.fromarray(arr).resize())``.
/mnt/disk1/project/.env/lib/python3.6/site-packages/ipykernel_launcher.py:22: DeprecationWarning: `imresize` is deprecated!
`imresize` is deprecated in SciPy 1.0.0, and will be removed in 1.3.0.
Use Pillow instead: ``numpy.array(Image.fromarray(arr).resize())``.